
Fix TorchAO v1 group offloading with use_stream=True #14112

Closed
Sunt-ing wants to merge 2 commits into huggingface:main from Sunt-ing:14


Sunt-ing commented Jul 2, 2026


What does this PR do?

Refs #13281.

This PR is scoped to the legacy TorchAO int8 weight-only path, Int8WeightOnlyConfig(version=1), which produces AffineQuantizedTensor. The Float8WeightOnlyConfig path reported in #13281 and the int8 compile path that uses TorchAO version=2 both already support is_pinned() and pin_memory() on current main, so they are not affected.

The streamed group-offload path keeps a CPU copy of each tensor and normally pins that copy before transferring a group back to the accelerator. AffineQuantizedTensor is still a TorchAO tensor subclass, so _to_cpu() must call tensor.cpu(), but its pinning ops raise NotImplementedError: ... aten.is_pinned.

This PR keeps pinned memory for tensors whose pinning ops work, including TorchAO v2 tensors, and falls back to the CPU copy only when a TorchAO tensor does not implement those pinning ops.
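A minimal sketch of that fallback, for illustration only. It is not the diff from this PR: `to_cpu_with_pin_fallback`, `may_skip_pinning`, and the two stand-in tensor classes are hypothetical names, and the stand-ins only mimic the `cpu()` / `pin_memory()` behaviour described above so the logic can run without a GPU or TorchAO installed.

```python
class PinnableTensor:
    """Hypothetical stand-in for a tensor whose pinning ops work."""

    def __init__(self, pinned=False):
        self.pinned = pinned

    def cpu(self):
        # A fresh CPU copy starts out unpinned.
        return type(self)(pinned=False)

    def pin_memory(self):
        return type(self)(pinned=True)


class UnpinnableTensor(PinnableTensor):
    """Hypothetical stand-in for a subclass that lacks pinning ops,
    as reported above for AffineQuantizedTensor."""

    def pin_memory(self):
        raise NotImplementedError("aten.is_pinned")


def to_cpu_with_pin_fallback(tensor, low_cpu_mem_usage=False, may_skip_pinning=lambda t: True):
    """Return a CPU copy, pinned when possible.

    `may_skip_pinning` is an assumed predicate deciding which tensors are
    allowed to stay unpinned (for example, only TorchAO tensors).
    """
    cpu_copy = tensor.cpu()
    if low_cpu_mem_usage:
        # Pinning is skipped entirely in low-memory mode.
        return cpu_copy
    try:
        return cpu_copy.pin_memory()
    except NotImplementedError:
        if may_skip_pinning(cpu_copy):
            # Fall back to the plain (unpinned) CPU copy.
            return cpu_copy
        raise


print(to_cpu_with_pin_fallback(PinnableTensor()).pinned)
print(to_cpu_with_pin_fallback(UnpinnableTensor()).pinned)
```

Catching only NotImplementedError keeps every other pinning failure visible, and the predicate keeps the fallback from silently applying to ordinary tensors.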

End-to-end reproduction

Environment: NVIDIA RTX 4090, torch==2.8.0+cu128, torchao==0.17.0.

The script uses the public tiny Flux pipeline, quantizes the transformer with Int8WeightOnlyConfig(version=1), enables pipe.transformer.enable_group_offload(..., use_stream=True), moves the remaining modules to CUDA, and runs pipe(...).

import numpy as np
import torch

import diffusers
from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from torchao.quantization import Int8WeightOnlyConfig

print(f"diffusers_file={diffusers.__file__}")
print(f"torch={torch.__version__}")
print(f"cuda={torch.cuda.get_device_name(0)}")

quantization_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(version=1))}
)
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
pipe.set_progress_bar_config(disable=True)
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True,
)

for name, component in pipe.components.items():
    if name != "transformer" and isinstance(component, torch.nn.Module):
        if torch.device(component.device).type == "cpu":
            component.to("cuda")

images = pipe(
    "a dog",
    num_inference_steps=2,
    max_sequence_length=16,
    height=32,
    width=32,
    output_type="np",
).images
arr = np.asarray(images)
print("RESULT=PASS")
print(f"output_shape={arr.shape}")
print(f"finite={np.isfinite(arr).all()}")
print(f"mean={arr.mean():.6f}")

Before:

diffusers_file=/root/autodl-tmp/code/diffusers-14-e2e-before/src/diffusers/__init__.py
torch=2.8.0+cu128
cuda=NVIDIA GeForce RTX 4090
RESULT=FAIL
exception:
Traceback (most recent call last):
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/repro_torchao_stream_e2e.py", line 39, in main
    pipe.transformer.enable_group_offload(
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/src/diffusers/models/modeling_utils.py", line 573, in enable_group_offload
    apply_group_offloading(
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/src/diffusers/hooks/group_offloading.py", line 702, in apply_group_offloading
    _apply_group_offloading(module, config)
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/src/diffusers/hooks/group_offloading.py", line 709, in _apply_group_offloading
    _apply_group_offloading_leaf_level(module, config)
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/src/diffusers/hooks/group_offloading.py", line 827, in _apply_group_offloading_leaf_level
    group = ModuleGroup(
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/src/diffusers/hooks/group_offloading.py", line 167, in __init__
    self.cpu_param_dict = self._init_cpu_param_dict()
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/src/diffusers/hooks/group_offloading.py", line 189, in _init_cpu_param_dict
    cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/src/diffusers/hooks/group_offloading.py", line 180, in _to_cpu
    return t if low_cpu_mem_usage else t.pin_memory()
  File "/root/autodl-tmp/code/diffusers-14-e2e-before/pydeps/torchao/utils.py", line 684, in _dispatch__torch_dispatch__
    raise NotImplementedError(
NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
EXIT_CODE=1

After:

diffusers_file=/root/autodl-tmp/code/diffusers-14-retarget-codex/src/diffusers/__init__.py
torch=2.8.0+cu128
cuda=NVIDIA GeForce RTX 4090
UserWarning: Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release
UserWarning: Deprecation: AffineQuantizedTensor is deprecated and will be removed in a future release of torchao
RESULT=PASS
output_shape=(1, 32, 32, 3)
finite=True
mean=0.504391

Additional checks:

python -m py_compile src/diffusers/hooks/group_offloading.py tests/quantization/torchao/test_torchao.py
git diff --check
ruff check src/diffusers/hooks/group_offloading.py tests/quantization/torchao/test_torchao.py
All checks passed!

python -m pytest tests/quantization/torchao/test_torchao.py::TorchAoTest::test_group_offloading_torchao_int8wo_v1 -q -rs
1 passed, 11 warnings in 34.10s

TestFluxTransformerTorchAoCompile._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], use_stream=True)
RESULT=PASS

Who can review?

cc @sayakpaul

github-actions Bot added the size/S (PR with diff < 50 LOC), fixes-issue, tests, and hooks labels and removed the size/S (PR with diff < 50 LOC) and fixes-issue labels Jul 2, 2026
@sayakpaul

Member

I ran pytest tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAoCompile -k "test_torchao_torch_compile_with_group_offload" with the following diff:

diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 95b6b0fc6..ea81538c3 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -1377,7 +1377,7 @@ class TorchAoCompileTesterMixin(TorchAoConfigMixin, QuantizationCompileTesterMix
 
     @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
     def test_torchao_torch_compile_with_group_offload(self, quant_type):
-        self._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])
+        self._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type], use_stream=True)
 
 
 @is_gguf

and it passed. What am I missing?


Sunt-ing commented Jul 3, 2026

Author

Thanks @sayakpaul, you are not missing anything. Your test uses the TorchAO v2 path, and that path passes because the resulting Int8Tensor supports is_pinned() and pin_memory().

I narrowed this PR to the legacy Int8WeightOnlyConfig(version=1) path, which produces deprecated AffineQuantizedTensor. That tensor still raises NotImplementedError: ... aten.is_pinned; Int8WeightOnlyConfig(version=2) and Float8WeightOnlyConfig support the pinning ops in my check.

The follow-up patch no longer skips pinning for all TorchAO tensors. It keeps the normal pinned-memory path and only falls back to the CPU copy when a TorchAO tensor raises NotImplementedError from the pinning ops.

Checks
Int8WeightOnlyConfig(version=1)
tensor type: torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor
is_pinned: NotImplementedError: ... aten.is_pinned
pin_memory: NotImplementedError: ... aten.is_pinned

Int8WeightOnlyConfig(version=2)
tensor type: torchao.quantization.Int8Tensor
is_pinned: False
pin_memory: OK

Float8WeightOnlyConfig()
tensor type: torchao.quantization.Float8Tensor
is_pinned: False
pin_memory: OK

python -m pytest tests/quantization/torchao/test_torchao.py::TorchAoTest::test_group_offloading_torchao_int8wo_v1 -q -rs
1 passed, 11 warnings in 34.10s

TestFluxTransformerTorchAoCompile._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], use_stream=True)
RESULT=PASS

tiny Flux pipeline with Int8WeightOnlyConfig(version=1) + pipe.transformer.enable_group_offload(..., use_stream=True)
RESULT=PASS
output_shape=(1, 32, 32, 3)
finite=True
mean=0.504391

Sunt-ing changed the title from "Fix torchao group offloading with use_stream=True" to "Fix TorchAO v1 group offloading with use_stream=True" Jul 3, 2026
github-actions Bot added the size/S (PR with diff < 50 LOC) label Jul 3, 2026
@sayakpaul

Member

It is not recommended to use v1. If things work fine with v2, then I don't think any fixes (like the ones introduced in this PR) are needed at all.
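For reference, a sketch of the v2 setup this comment points to. It assumes the same tiny Flux checkpoint and offload settings as the reproduction script in the PR description and changes only the `version` argument. It was not run as part of this thread; `build_v2_offloaded_pipe` is a hypothetical helper name, and the imports sit inside the function only so the snippet can be loaded on a machine without diffusers or torchao.

```python
def build_v2_offloaded_pipe(model_id="hf-internal-testing/tiny-flux-pipe"):
    """Quantize the transformer with Int8WeightOnlyConfig(version=2) and
    enable streamed group offloading (hypothetical helper, untested here)."""
    import torch
    from diffusers import DiffusionPipeline, TorchAoConfig
    from diffusers.quantizers import PipelineQuantizationConfig
    from torchao.quantization import Int8WeightOnlyConfig

    # Only difference from the reproduction script: version=2 instead of version=1.
    quantization_config = PipelineQuantizationConfig(
        quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(version=2))}
    )
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )
    pipe.transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
        non_blocking=True,
    )
    return pipe
```

According to the checks reported earlier in this thread, the v2 tensors support the pinning ops, so this configuration should not hit the aten.is_pinned error.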


Sunt-ing commented Jul 3, 2026

Author

Thanks @sayakpaul, that makes sense. Since this is now scoped to the deprecated v1 path and the current v2/Float8 paths work, I'm going to close this PR.

Sunt-ing closed this Jul 3, 2026